{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "10c559a0-e178-412b-9d80-93bbd12cb5ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8bfbfacd-af87-4876-982f-8cc27174bdd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision\n",
    "from torchvision.transforms import ToTensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fc36f74d-0370-440b-bd44-e01f0fa881ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the torch seed for deterministic behavior\n",
    "torch.manual_seed(0);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ff7a2fe8-2437-464c-ba22-3d520642e5ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = torchvision.datasets.MNIST(\n",
    "    '/tmp',\n",
    "    train=True,\n",
    "    download=True,\n",
    "    transform=ToTensor()\n",
    ")\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    train_data,\n",
    "    batch_size=64,\n",
    "    shuffle=True,\n",
    "    num_workers=4\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ce59c008-eca0-4e22-8b22-29f81a236092",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_data = torchvision.datasets.MNIST(\n",
    "    '/tmp',\n",
    "    train=True,\n",
    "    download=True,\n",
    "    transform=ToTensor()\n",
    ")\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    test_data,\n",
    "    batch_size=64,\n",
    "    shuffle=True,\n",
    "    num_workers=4\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c02fde30-bf01-41d6-abe5-b953b2d3ce6d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f10927e0a90>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAN80lEQVR4nO3df6hcdXrH8c+ncf3DrBpTMYasNhuRWBWbLRqLSl2RrD9QNOqWDVgsBrN/GHChhEr6xyolEuqP0qAsuYu6sWyzLqgYZVkVo6ZFCF5j1JjU1YrdjV6SSozG+KtJnv5xT+Su3vnOzcyZOZP7vF9wmZnzzJnzcLife87Md879OiIEYPL7k6YbANAfhB1IgrADSRB2IAnCDiRxRD83ZpuP/oEeiwiPt7yrI7vtS22/aftt27d281oAesudjrPbniLpd5IWSNou6SVJiyJia2EdjuxAj/XiyD5f0tsR8U5EfCnpV5Ku6uL1APRQN2GfJekPYx5vr5b9EdtLbA/bHu5iWwC61M0HdOOdKnzjND0ihiQNSZzGA03q5si+XdJJYx5/R9L73bUDoFe6CftLkk61/V3bR0r6kaR19bQFoG4dn8ZHxD7bSyU9JWmKpAci4o3aOgNQq46H3jraGO/ZgZ7ryZdqABw+CDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUii4ymbcXiYMmVKsX7sscf2dPtLly5tWTvqqKOK686dO7dYv/nmm4v1u+66q2Vt0aJFxXU///zzYn3lypXF+u23316sN6GrsNt+V9IeSfsl7YuIs+toCkD96jiyXxQRH9TwOgB6iPfsQBLdhj0kPW37ZdtLxnuC7SW2h20Pd7ktAF3o9jT+/Ih43/YJkp6x/V8RsWHsEyJiSNKQJNmOLrcHoENdHdkj4v3qdqekxyTNr6MpAPXrOOy2p9o++uB9ST+QtKWuxgDUq5vT+BmSHrN98HX+PSJ+W0tXk8zJJ59crB955JHF+nnnnVesX3DBBS1r06ZNK6577bXXFutN2r59e7G+atWqYn3hwoUta3v27Cmu++qrrxbrL7zwQrE+iDoOe0S8I+kvauwFQA8x9AYkQdiBJAg7kARhB5Ig7EASjujfl9om6zfo5s2bV6yvX7++WO/1ZaaD6sCBA8X6jTfeWKx/8sknHW97ZGSkWP/www+L9TfffLPjbfdaRHi85RzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtlrMH369GJ948aNxfqcOXPqbKdW7XrfvXt3sX7RRRe1rH355ZfFdbN+/6BbjLMDyRF2IAnCDiRB2IEkCDuQBGEHkiDsQBJM2VyDXbt2FevLli0r1q+44opi/ZVXXinW2/1L5ZLNmzcX6wsWLCjW9+7dW6yfccYZLWu33HJLcV3UiyM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTB9ewD4JhjjinW200vvHr16pa1xYsXF9e9/vrri/W1a9cW6xg8HV/PbvsB2zttbxmzbLrtZ2y/Vd0eV2ezAOo3kdP4X0i69GvLbpX0bEScKunZ6jGAAdY27BGxQdLXvw96laQ11f01kq6uty0Adev0u/EzImJEkiJixPYJrZ5oe4mkJR1uB0BNen4hTEQMSRqS+IAOaFKnQ287bM+UpOp2Z30tAeiFTsO+TtIN1f0bJD1eTzsAeqXtabzttZK+L+l429sl/VTSSkm/tr1Y0u8l/bCXTU52H3/8cVfrf/TRRx2ve9NNNxXrDz/8cLHebo51DI62YY+IRS1KF9fcC4Ae4uuyQBKEHUiCsANJEHYgCcIOJMElrpPA1KlTW9aeeOKJ4roXXnhhsX7ZZZcV608//XSxjv5jymYgOcIOJEHYgSQIO5AEYQeSIOxAEoQdSIJx9knulFNOKdY3bdpUrO/evbtYf+6554r14eHhlrX77ruvuG4/fzcnE8bZgeQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtmTW7hwYbH+4IMPFutHH310x9tevnx5sf7QQw8V6yMjIx1vezJjnB1IjrADSRB2IAnCDiRB2IEkCDuQBGEHkmCcHUVnnnlmsX7PPfcU6xdf3Plkv6tXry7WV6xYUay/9957HW/7cNbxOLvtB2zvtL1lzLLbbL9ne3P1c3mdzQKo30RO438h6dJxlv9LRMyrfn5Tb1sA6tY27BGxQdKuPvQCoIe6+YBuqe3XqtP841o9yfYS28O2W/8zMgA912nYfybpFEnzJI1IurvVEyNiKCLOjoizO9wWgBp0FPaI2BER+yPigKSfS5pfb1sA6tZR2G3PHPNwoaQtrZ4LYDC0HWe3vVbS9yUdL2mHpJ9Wj+dJCknvSvpxRLS9uJhx9sln2rRpxfqVV17ZstbuWnl73OHir6xfv75YX7BgQbE+WbUaZz9iAisuGmfx/V13BKCv+LoskARhB5Ig7EAShB1IgrADSXCJKxrzxRdfFOtHHFEeLNq3b1+xfskll7SsPf/888V1D2f8K2kgOcIOJEHYgSQIO5AEYQeSIOxAEoQdSKLtVW/I7ayzzirWr7vuumL9nHPOaVlrN47eztatW4v1DRs2dPX6kw1HdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH2SW7u3LnF+tKlS4v1a665plg/8cQTD7mnidq/f3+xPjJS/u/lBw4cqLOdwx5HdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH2w0C7sexFi8abaHdUu3H02bNnd9JSLYaHh4v1FStWFOvr1q2rs51Jr+2R3fZJtp+zvc32G7ZvqZZPt/2M7beq2+N63y6ATk3kNH6fpL+PiD+X9FeSbrZ9uqRbJT0bEadKerZ6DGBAtQ17RIxExKbq/h5J2yTNknSVpDXV09ZIurpHPQKowSG9Z7c9W9L3JG2UNCMiRqTRPwi2T2ixzhJJS7rsE0CXJhx229+W9Iikn0TEx/a4c8d9Q0QMSRqqXoOJHYGGTGjozfa3NBr0X0bEo9XiHbZnVvWZknb2pkUAdWh7ZPfoIfx+Sdsi4p4xpXWSbpC0srp9vCcdTgIzZswo1k8//fRi/d577y3WTzvttEPuqS4bN24s1u+8886WtccfL//KcIlqvSZyGn++pL+V9LrtzdWy5RoN+a9tL5b0e0k/7EmHAGrRNuwR8Z+SWr1Bv7jedgD0Cl+XBZIg7EAShB1IgrADSRB2IAkucZ2g6dOnt6ytXr26uO68efOK9Tlz5nTSUi1efPHFYv3uu+8u1p966qli/bPPPjvkntAbHNmBJAg7kARhB5Ig7EAShB1IgrADSRB2IIk04+znnntusb5s2bJiff78+S1rs2bN6qinunz66acta6tWrSque8cddxTre/fu7agnDB6O7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQRJpx9oULF3ZV78bWrVuL9SeffLJY37dvX7FeuuZ89+7dxXWRB0d2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUjCEVF+gn2SpIcknSjpgKShiPhX27dJuknS/1ZPXR4Rv2nzWuWNAehaRIw76/JEwj5T0syI2GT7aEkvS7pa0t9I+iQi7ppoE4Qd6L1WYZ/I/Owjkkaq+3tsb5PU7L9mAXDIDuk9u+3Zkr4naWO1aKnt12w/YPu4FusssT1se7i7VgF0o+1p/FdPtL8t6QVJKyLiUdszJH0gKST9k0ZP9W9s8xqcxgM91vF7dkmy/S1JT0p6KiLuGac+W9KTEXFmm9ch7ECPtQp729N425Z0v6RtY4NefXB30EJJW7ptEkDvTOTT+Ask/Yek1zU69CZJyyUtkjRPo6fx70r6cfVhXum1OLIDPdbVaXxdCDvQex2fxgOYHAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJ9HvK5g8k/c+Yx8dXywbRoPY2qH1J9NapOnv7s1aFvl7P/o2N28MRcXZjDRQMam+D2pdEb53qV2+cxgNJEHYgiabDPtTw9ksGtbdB7Uuit071pbdG37MD6J+mj+wA+oSwA0k0Enbbl9p+0/bbtm9toodWbL9r+3Xbm5uen66aQ2+n7S1jlk23/Yztt6rbcefYa6i322y/V+27zbYvb6i3k2w/Z3ub7Tds31Itb3TfFfrqy37r+3t221Mk/U7SAknbJb0kaVFEbO1rIy3YflfS2RHR+BcwbP+1pE8kPXRwai3b/yxpV0SsrP5QHhcR/zAgvd2mQ5zGu0e9tZpm/O/U4L6rc/rzTjRxZJ8v6e2IeCcivpT0K0lXNdDHwIuIDZJ2fW3xVZLWVPfXaPSXpe9a9DYQImIkIjZV9/dIOjjNeKP7rtBXXzQR9lmS/jDm8XYN1nzvIelp2y/bXtJ0M+OYcXCarer2hIb7+bq203j309emGR+YfdfJ9OfdaiLs401NM0jjf+dHxF9KukzSzdXpKibmZ5JO0egcgCOS7m6ymWqa8Uck/SQiPm6yl7HG6asv+62JsG+XdNKYx9+R9H4DfYwrIt6vbndKekyjbzsGyY6DM+hWtzsb7ucrEbEjIvZHxAFJP1eD+66aZvwRSb+MiEerxY3vu/H66td+ayLsL0k61fZ3bR8p6UeS1jXQxzfYnlp9cCLbUyX9QIM3FfU6STdU92+Q9HiDvfyRQZnGu9U042p43zU+/XlE9P1H0uUa/UT+vyX9YxM9tOhrjqRXq583mu5N0lqNntb9n0bPiBZL+lNJz0p6q7qdPkC9/ZtGp/Z+TaPBmtlQbxdo9K3ha5I2Vz+XN73vCn31Zb/xdVkgCb5BByRB2IEkCDuQBGEHkiDsQBKEHUiCsANJ/D+f1mbt6t55/AAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.imshow(test_data[0][0].numpy()[0,:,:], cmap='gray')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fe0f11ad-8ef4-418e-89f0-010413413505",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f1092744df0>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAOE0lEQVR4nO3dcYxV5ZnH8d8jLUalENSIE9HabTDZptFBkJDYrKxNG4sm0JiuEOOw2SZDYknQNKZqRyGpGxujNGoicaqkWFmhihZs1qWGIbobk8YRWcWyrdRQHJkwokaGmEiFZ/+YQzPinPcM955zz4Xn+0km997zzLnn8To/zrn3Pee+5u4CcOo7re4GALQGYQeCIOxAEIQdCIKwA0F8qZUbMzM++gcq5u421vKm9uxmdo2Z/cnMdpvZ7c08F4BqWaPj7GY2QdKfJX1H0oCkVyUtdvc/JtZhzw5UrIo9+xxJu939HXc/LGm9pAVNPB+ACjUT9gskvTvq8UC27HPMrNvM+s2sv4ltAWhSMx/QjXWo8IXDdHfvldQrcRgP1KmZPfuApAtHPZ4uaV9z7QCoSjNhf1XSDDP7mplNlLRI0uZy2gJQtoYP4939MzNbJmmLpAmS1rj7W6V1BqBUDQ+9NbQx3rMDlavkpBoAJw/CDgRB2IEgCDsQBGEHgiDsQBCEHQiCsANBEHYgCMIOBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0EQdiAIwg4EQdiBIAg7EARhB4Jo6ZTNOPXMmjUrWV+2bFluraurK7nuE088kaw//PDDyfr27duT9WjYswNBEHYgCMIOBEHYgSAIOxAEYQeCIOxAEMziiqTOzs5kva+vL1mfPHlyid183scff5ysn3POOZVtu53lzeLa1Ek1ZrZH0rCkI5I+c/fZzTwfgOqUcQbdP7v7gRKeB0CFeM8OBNFs2F3S783sNTPrHusXzKzbzPrNrL/JbQFoQrOH8Ve6+z4zO0/Si2b2f+7+8uhfcPdeSb0SH9ABdWpqz+7u+7LbIUnPSZpTRlMAytdw2M3sLDP7yrH7kr4raWdZjQEoVzOH8dMkPWdmx57nP9z9v0rpCi0zZ076YGzjxo3J+pQpU5L11Hkcw8PDyXUPHz6crBeNo8+dOze3VnSte9G2T0YNh93d35F0WYm9AKgQQ29AEIQdCIKwA0EQdiAIwg4EwSWup4Azzzwzt3b55Zcn133yySeT9enTpyfr2dBrrtTfV9Hw13333Zesr1+/PllP9dbT05Nc9957703W21neJa7s2YEgCDsQBGEHgiDsQBCEHQiCsANBEHYgCKZsPgU8+uijubXFixe3sJMTU3QOwKRJk5L1l156KVmfN29ebu3SSy9NrnsqYs8OBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0Ewzn4SmDVrVrJ+7bXX5taKrjcvUjSW/fzzzyfr999/f25t3759yXVff/31ZP2jjz5K1q+++urcWrOvy8mIPTsQBGEHgiDsQBCEHQiCsANBEHYgCMIOBMH3xreBzs7OZL2vry9Znzx5csPbfuGFF5L1ouvhr7rqqmQ9dd34Y489llz3/fffT9aLHDlyJLf2ySefJNct+u8q+s77OjX8vfFmtsbMhsxs56hlZ5vZi2b2dnY7tcxmAZRvPIfxv5J0zXHLbpe01d1nSNqaPQbQxgrD7u4vS/rwuMULJK3N7q+VtLDctgCUrdFz46e5+6AkufugmZ2X94tm1i2pu8HtAChJ5RfCuHuvpF6JD+iAOjU69LbfzDokKbsdKq8lAFVoNOybJS3J7i+RtKmcdgBUpXCc3cyekjRP0rmS9ktaIem3kn4j6SJJeyX9wN2P/xBvrOcKeRh/ySWXJOsrVqxI1hctWpSsHzhwILc2ODiYXPeee+5J1p955plkvZ2lxtmL/u43bNiQrN94440N9dQKeePshe/Z3T3vrIpvN9URgJbidFkgCMIOBEHYgSAIOxAEYQeC4KukS3D66acn66mvU5ak+fPnJ+vDw8PJeldXV26tv78/ue4ZZ5yRrEd10UUX1d1C6dizA0EQdiAIwg4EQdiBIAg7EARhB4Ig7EAQjLOXYObMmcl60Th6kQULFiTrRdMqAxJ7diAMwg4EQdiBIAg7EARhB4Ig7EAQhB0IgnH2EqxatSpZNxvzm33/rmicnHH0xpx2Wv6+7OjRoy3spD2wZweCIOxAEIQdCIKwA0EQdiAIwg4EQdiBIBhnH6frrrsut9bZ2Zlct2h64M2bNzfSEgqkxtKL/p/s2LGj5G7qV7hnN7M1ZjZkZjtHLVtpZu+Z2Y7sp7lvZwBQufEcxv9K0jVjLP+Fu3dmP/9ZblsAylYYdnd/WdKHLegFQIWa+YBumZm9kR3mT837JTPrNrN+M0tPOgagUo2GfbWkr0vqlDQo6YG8X3T3Xnef7e6zG9wWgBI0FHZ33+/uR9z9qKRfSppTblsAytZQ2M2sY9TD70vamfe7ANpD4Ti7mT0laZ6kc81sQNIKSfPMrFOSS9ojaWl1LbaH1DzmEydOTK47NDSUrG/YsKGhnk51RfPer1y5suHn7uvrS9bvuOOOhp+7XRWG3d0Xj7H48Qp6AVAhTpcFgiDsQBCEHQiCsANBEHYgCC5xbYFPP/00WR8cHGxRJ+2laGitp6cnWb/tttuS9YGBgdzaAw/knvQpSTp06FCyfjJizw4EQdiBIAg7EARhB4Ig7EAQhB0IgrADQTDO3gKRvyo69TXbRePkN9xwQ7K+adOmZP36669P1qNhzw4EQdiBIAg7EARhB4Ig7EAQhB0IgrADQTDOPk5m1lBNkhYuXJisL1++vJGW2sKtt96arN911125tSlTpiTXXbduXbLe1dWVrOPz2LMDQRB2IAjCDgRB2IEgCDsQBGEHgiDsQBCMs4+TuzdUk6Tzzz8/WX/ooYeS9TVr1iTrH3zwQW5t7ty5yXVvuummZP2yyy5L1qdPn56s7927N7e2ZcuW5LqPPPJIso4TU7hnN7MLzWybme0ys7fMbHm2/Gwze9HM3s5up1bfLoBGjecw/jNJP3b3f5Q0V9KPzOwbkm6XtNXdZ0jamj0G0KYKw+7ug+6+Pbs/LGmXpAskLZC0Nvu1tZIWVtQjgBKc0Ht2M7tY0kxJf5A0zd0HpZF/EMzsvJx1uiV1N9kngCaNO+xmNknSRkm3uPvBoos/jnH3Xkm92XOkP8kCUJlxDb2Z2Zc1EvR17v5stni/mXVk9Q5JQ9W0CKAMhXt2G9mFPy5pl7uvGlXaLGmJpJ9nt+nv9Q1swoQJyfrNN9+crBd9JfLBgwdzazNmzEiu26xXXnklWd+2bVtu7e677y67HSSM5zD+Skk3SXrTzHZky+7USMh/Y2Y/lLRX0g8q6RBAKQrD7u7/IynvDfq3y20HQFU4XRYIgrADQRB2IAjCDgRB2IEgrOjyzFI3dhKfQZe6lPPpp59OrnvFFVc0te2isxWb+X+YujxWktavX5+sn8xfg32qcvcx/2DYswNBEHYgCMIOBEHYgSAIOxAEYQeCIOxAEIyzl6CjoyNZX7p0abLe09OTrDczzv7ggw8m1129enWyvnv37mQd7YdxdiA4wg4EQdiBIAg7EARhB4Ig7EAQhB0IgnF24BTDODsQHGEHgiDsQBCEHQiCsANBEHYgCMIOBFEYdjO70My2mdkuM3vLzJZny1ea2XtmtiP7mV99uwAaVXhSjZl1SOpw9+1m9hVJr0laKOlfJB1y9/vHvTFOqgEql3dSzXjmZx+UNJjdHzazXZIuKLc9AFU7offsZnaxpJmS/pAtWmZmb5jZGjObmrNOt5n1m1l/c60CaMa4z403s0mSXpL07+7+rJlNk3RAkkv6mUYO9f+t4Dk4jAcqlncYP66wm9mXJf1O0hZ3XzVG/WJJv3P3bxY8D2EHKtbwhTA28tWmj0vaNTro2Qd3x3xf0s5mmwRQnfF8Gv8tSf8t6U1JR7PFd0paLKlTI4fxeyQtzT7MSz0Xe3agYk0dxpeFsAPV43p2IDjCDgRB2IEgCDsQBGEHgiDsQBCEHQiCsANBEHYgCMIOBEHYgSAIOxAEYQeCIOxAEIVfOFmyA5L+OurxudmydtSuvbVrXxK9NarM3r6aV2jp9exf2LhZv7vPrq2BhHbtrV37kuitUa3qjcN4IAjCDgRRd9h7a95+Srv21q59SfTWqJb0Vut7dgCtU/eeHUCLEHYgiFrCbmbXmNmfzGy3md1eRw95zGyPmb2ZTUNd6/x02Rx6Q2a2c9Sys83sRTN7O7sdc469mnpri2m8E9OM1/ra1T39ecvfs5vZBEl/lvQdSQOSXpW02N3/2NJGcpjZHkmz3b32EzDM7J8kHZL0xLGptczsPkkfuvvPs38op7r7T9qkt5U6wWm8K+otb5rxf1WNr12Z0583oo49+xxJu939HXc/LGm9pAU19NH23P1lSR8et3iBpLXZ/bUa+WNpuZze2oK7D7r79uz+sKRj04zX+tol+mqJOsJ+gaR3Rz0eUHvN9+6Sfm9mr5lZd93NjGHasWm2stvzau7neIXTeLfScdOMt81r18j0582qI+xjTU3TTuN/V7r75ZK+J+lH2eEqxme1pK9rZA7AQUkP1NlMNs34Rkm3uPvBOnsZbYy+WvK61RH2AUkXjno8XdK+GvoYk7vvy26HJD2nkbcd7WT/sRl0s9uhmvv5O3ff7+5H3P2opF+qxtcum2Z8o6R17v5strj2126svlr1utUR9lclzTCzr5nZREmLJG2uoY8vMLOzsg9OZGZnSfqu2m8q6s2SlmT3l0jaVGMvn9Mu03jnTTOuml+72qc/d/eW/0iar5FP5P8i6ad19JDT1z9I+t/s5626e5P0lEYO6/6mkSOiH0o6R9JWSW9nt2e3UW+/1sjU3m9oJFgdNfX2LY28NXxD0o7sZ37dr12ir5a8bpwuCwTBGXRAEIQdCIKwA0EQdiAIwg4EQdiBIAg7EMT/Az6wY9VChzNWAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.imshow(test_data[1][0].numpy()[0,:,:], cmap='gray')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0160bd5f-8b12-4d9a-8f49-c5cec162357b",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_data[0][0].numpy().flatten().tofile('5.mnist.bin')\n",
    "test_data[1][0].numpy().flatten().tofile('0.mnist.bin')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "9f5edc09-bb82-48b3-a317-2f60590d7d50",
   "metadata": {},
   "outputs": [],
   "source": [
    "!mv *.mnist.bin ../src/main/resources/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3b426cde-d514-4ea6-98f5-a926a30bdd50",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LinearMNIST(nn.Module):\n",
    " \n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.linear_stack = nn.Sequential(\n",
    "            nn.Linear(28 * 28, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(512, 10),\n",
    "            nn.Softmax(dim=1)\n",
    "        )\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = torch.flatten(x, start_dim=1)\n",
    "        y = self.linear_stack(x)\n",
    "        return y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "9c89eca6-10ae-4b06-965d-26ad21cc95cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = LinearMNIST()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "e07cb9bc-0285-4dbd-889f-34263f5639a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "learning_rate = 1e-3\n",
    "batch_size = 64\n",
    "epochs = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "e7546c8e-f886-48e9-b3f8-07f20433065d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_loop(dataloader, model, loss_fn, optimizer):\n",
    "    size = len(dataloader.dataset)\n",
    "    for batch, (X, y) in enumerate(dataloader):\n",
    "        # Compute prediction and loss\n",
    "        pred = model(X)\n",
    "        loss = loss_fn(pred, y)\n",
    "\n",
    "        # Backpropagation\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if batch % 100 == 0:\n",
    "            loss, current = loss.item(), batch * len(X)\n",
    "            print(f\"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]\")\n",
    "\n",
    "\n",
    "def test_loop(dataloader, model, loss_fn):\n",
    "    size = len(dataloader.dataset)\n",
    "    test_loss, correct = 0, 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for X, y in dataloader:\n",
    "            pred = model(X)\n",
    "            test_loss += loss_fn(pred, y).item()\n",
    "            correct += (pred.argmax(1) == y).type(torch.float).sum().item()\n",
    "\n",
    "    test_loss /= size\n",
    "    correct /= size\n",
    "    print(f\"Test Error: \\n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "ac6e1dbd-d5e4-4dce-b47e-569801f46f20",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1\n",
      "-------------------------------\n",
      "loss: 2.300593  [    0/60000]\n",
      "loss: 1.613943  [ 6400/60000]\n",
      "loss: 1.586333  [12800/60000]\n",
      "loss: 1.590376  [19200/60000]\n",
      "loss: 1.573111  [25600/60000]\n",
      "loss: 1.608075  [32000/60000]\n",
      "loss: 1.591978  [38400/60000]\n",
      "loss: 1.513969  [44800/60000]\n",
      "loss: 1.527455  [51200/60000]\n",
      "loss: 1.580671  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 94.1%, Avg loss: 0.023849 \n",
      "\n",
      "Epoch 2\n",
      "-------------------------------\n",
      "loss: 1.533464  [    0/60000]\n",
      "loss: 1.533781  [ 6400/60000]\n",
      "loss: 1.497141  [12800/60000]\n",
      "loss: 1.510093  [19200/60000]\n",
      "loss: 1.489342  [25600/60000]\n",
      "loss: 1.487759  [32000/60000]\n",
      "loss: 1.533878  [38400/60000]\n",
      "loss: 1.496870  [44800/60000]\n",
      "loss: 1.502997  [51200/60000]\n",
      "loss: 1.511436  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 96.4%, Avg loss: 0.023473 \n",
      "\n",
      "Epoch 3\n",
      "-------------------------------\n",
      "loss: 1.466511  [    0/60000]\n",
      "loss: 1.516493  [ 6400/60000]\n",
      "loss: 1.507017  [12800/60000]\n",
      "loss: 1.477768  [19200/60000]\n",
      "loss: 1.498860  [25600/60000]\n",
      "loss: 1.534850  [32000/60000]\n",
      "loss: 1.475538  [38400/60000]\n",
      "loss: 1.492521  [44800/60000]\n",
      "loss: 1.506383  [51200/60000]\n",
      "loss: 1.495400  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 97.3%, Avg loss: 0.023308 \n",
      "\n",
      "Epoch 4\n",
      "-------------------------------\n",
      "loss: 1.495229  [    0/60000]\n",
      "loss: 1.478192  [ 6400/60000]\n",
      "loss: 1.530019  [12800/60000]\n",
      "loss: 1.525730  [19200/60000]\n",
      "loss: 1.461684  [25600/60000]\n",
      "loss: 1.479893  [32000/60000]\n",
      "loss: 1.477323  [38400/60000]\n",
      "loss: 1.491534  [44800/60000]\n",
      "loss: 1.509336  [51200/60000]\n",
      "loss: 1.462319  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 97.9%, Avg loss: 0.023203 \n",
      "\n",
      "Epoch 5\n",
      "-------------------------------\n",
      "loss: 1.473798  [    0/60000]\n",
      "loss: 1.481868  [ 6400/60000]\n",
      "loss: 1.479568  [12800/60000]\n",
      "loss: 1.489046  [19200/60000]\n",
      "loss: 1.485274  [25600/60000]\n",
      "loss: 1.503003  [32000/60000]\n",
      "loss: 1.501189  [38400/60000]\n",
      "loss: 1.486743  [44800/60000]\n",
      "loss: 1.468397  [51200/60000]\n",
      "loss: 1.476093  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 98.2%, Avg loss: 0.023157 \n",
      "\n",
      "Epoch 6\n",
      "-------------------------------\n",
      "loss: 1.485514  [    0/60000]\n",
      "loss: 1.509447  [ 6400/60000]\n",
      "loss: 1.500317  [12800/60000]\n",
      "loss: 1.477655  [19200/60000]\n",
      "loss: 1.461433  [25600/60000]\n",
      "loss: 1.494387  [32000/60000]\n",
      "loss: 1.496834  [38400/60000]\n",
      "loss: 1.516416  [44800/60000]\n",
      "loss: 1.499406  [51200/60000]\n",
      "loss: 1.477836  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 98.5%, Avg loss: 0.023098 \n",
      "\n",
      "Epoch 7\n",
      "-------------------------------\n",
      "loss: 1.477426  [    0/60000]\n",
      "loss: 1.469659  [ 6400/60000]\n",
      "loss: 1.463242  [12800/60000]\n",
      "loss: 1.463617  [19200/60000]\n",
      "loss: 1.497176  [25600/60000]\n",
      "loss: 1.473976  [32000/60000]\n",
      "loss: 1.530201  [38400/60000]\n",
      "loss: 1.491543  [44800/60000]\n",
      "loss: 1.461583  [51200/60000]\n",
      "loss: 1.484636  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 98.6%, Avg loss: 0.023075 \n",
      "\n",
      "Epoch 8\n",
      "-------------------------------\n",
      "loss: 1.469275  [    0/60000]\n",
      "loss: 1.472346  [ 6400/60000]\n",
      "loss: 1.475016  [12800/60000]\n",
      "loss: 1.476902  [19200/60000]\n",
      "loss: 1.462929  [25600/60000]\n",
      "loss: 1.496112  [32000/60000]\n",
      "loss: 1.481316  [38400/60000]\n",
      "loss: 1.487520  [44800/60000]\n",
      "loss: 1.478001  [51200/60000]\n",
      "loss: 1.488994  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 98.8%, Avg loss: 0.023051 \n",
      "\n",
      "Epoch 9\n",
      "-------------------------------\n",
      "loss: 1.463391  [    0/60000]\n",
      "loss: 1.491408  [ 6400/60000]\n",
      "loss: 1.476874  [12800/60000]\n",
      "loss: 1.463158  [19200/60000]\n",
      "loss: 1.461942  [25600/60000]\n",
      "loss: 1.461518  [32000/60000]\n",
      "loss: 1.463728  [38400/60000]\n",
      "loss: 1.479201  [44800/60000]\n",
      "loss: 1.463128  [51200/60000]\n",
      "loss: 1.461338  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 98.9%, Avg loss: 0.023030 \n",
      "\n",
      "Epoch 10\n",
      "-------------------------------\n",
      "loss: 1.488442  [    0/60000]\n",
      "loss: 1.463154  [ 6400/60000]\n",
      "loss: 1.491115  [12800/60000]\n",
      "loss: 1.476404  [19200/60000]\n",
      "loss: 1.463899  [25600/60000]\n",
      "loss: 1.477808  [32000/60000]\n",
      "loss: 1.479558  [38400/60000]\n",
      "loss: 1.478052  [44800/60000]\n",
      "loss: 1.503852  [51200/60000]\n",
      "loss: 1.473328  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 98.9%, Avg loss: 0.023028 \n",
      "\n",
      "Done!\n"
     ]
    }
   ],
   "source": [
    "loss_fn = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
    "\n",
    "epochs = 10\n",
    "for t in range(epochs):\n",
    "    print(f\"Epoch {t+1}\\n-------------------------------\")\n",
    "    train_loop(train_loader, model, loss_fn, optimizer)\n",
    "    test_loop(test_loader, model, loss_fn)\n",
    "print(\"Done!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "9a17a126-e5f8-4353-a4b5-e545a59ee37d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[1.3431e-17, 4.0926e-13, 5.2853e-11, 2.7454e-05, 7.5811e-35, 9.9997e-01,\n",
       "          1.1170e-27, 4.5161e-21, 3.1156e-16, 2.5034e-17]],\n",
       "        grad_fn=<SoftmaxBackward>),\n",
       " tensor(5),\n",
       " 5)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_pred = model(test_data[0][0])\n",
    "y_pred, y_pred.argmax(), test_data[0][1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "f08e543d-a97b-4e0c-a44c-d8002ca74a96",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "graph(%x : Float(1:1, 28:28, 28:1, requires_grad=0, device=cpu),\n",
      "      %linear_stack.0.weight : Float(512:784, 784:1, requires_grad=1, device=cpu),\n",
      "      %linear_stack.0.bias : Float(512:1, requires_grad=1, device=cpu),\n",
      "      %linear_stack.2.weight : Float(10:512, 512:1, requires_grad=1, device=cpu),\n",
      "      %linear_stack.2.bias : Float(10:1, requires_grad=1, device=cpu)):\n",
      "  %5 : Float(1:784, 784:1, requires_grad=0, device=cpu) = onnx::Flatten[axis=1](%x) # <ipython-input-10-b50a759e7ef0>:13:0\n",
      "  %6 : Float(1:512, 512:1, requires_grad=1, device=cpu) = onnx::Gemm[alpha=1., beta=1., transB=1](%5, %linear_stack.0.weight, %linear_stack.0.bias) # /home/colin/.miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/functional.py:1690:0\n",
      "  %7 : Float(1:512, 512:1, requires_grad=1, device=cpu) = onnx::Relu(%6) # /home/colin/.miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/functional.py:1136:0\n",
      "  %8 : Float(1:10, 10:1, requires_grad=1, device=cpu) = onnx::Gemm[alpha=1., beta=1., transB=1](%7, %linear_stack.2.weight, %linear_stack.2.bias) # /home/colin/.miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/functional.py:1690:0\n",
      "  %output : Float(1:10, 10:1, requires_grad=1, device=cpu) = onnx::Softmax[axis=1](%8) # /home/colin/.miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/functional.py:1512:0\n",
      "  return (%output)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "torch.onnx.export(\n",
    "    model,\n",
    "    (test_data[0][0]),\n",
    "    'linear_mnist.onnx',\n",
    "    verbose=True,\n",
    "    input_names=['x'],\n",
    "    output_names=['output'],\n",
    "    dynamic_axes={\n",
    "        'x': {0: 'batch_size',},\n",
    "        'output': {0: 'batch_size'},\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "1acb8934-1945-4905-9f49-e56844f063fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "!mv linear_mnist.onnx ../src/main/resources/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a71935db-1b83-4ecf-ad1d-8ef14bff969d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pytorch]",
   "language": "python",
   "name": "conda-env-pytorch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
